import os, sys, fnmatch
sys.path.append('./xds/xds_python/')

import torch
import math
import copy
import argparse
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score
from pathlib import Path

from config.dataset_config import get_dataset_config_class
from config.hyparams_config import get_hyparams_config_class
from preprocess.data_loader import load_monkey_spike, monkey_spike_generator, get_monkey_spike_size, get_all_input_spike, monkey_spike_transform
from preprocess.utils import spike_preprocessing, plot_cursor_position, get_pred_cursor_pos
from xds_utils import find_target_dir
from model.model import vanilla_model, VAE_Model, disentangle_VAE_Model, train_disentangle_VAE_Model
from model.cursor_predictor import cursor_Predictor, train_cursor_predictor, get_cursor_position_structure, fine_tune_cursor_decoder
from model.lyapunov_spectrum import lyapunov_solve_unknown, extract_cell_states, extract_semantic_cell_state
from decoder.LSTM_decoder import vanilla_LSTM
from aligner.cycle_gan_based_aligner import train_cycle_gan_aligner, test_cycle_gan_aligner
from aligner.latent_aligner import train_latent_cycle_gan_aligner, test_latent_cycle_gan_aligner

import random
import openpyxl
import wandb
import time


def setup_seed(seed):
    """Seed every random number generator used by the pipeline.

    Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and all
    CUDA devices) all receive the same seed. cuDNN is also switched to
    deterministic kernels, so repeated runs with the same seed are as
    reproducible as these settings allow.

    Args:
        seed: integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN kernel autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True

if __name__ == '__main__':
    # Experiment driver.
    # Sweeps one dataset_config hyper-parameter (window_size, via para_list).
    # For every (random seed, split seed rs) pair, and for every recording day
    # in mat_list as target (the first, name-sorted day is always the source):
    #   1. load and zero-pad source/target spikes and cursor kinematics;
    #   2. split both days into train/test sets;
    #   3. train a cycle-GAN aligner from target windows to source windows;
    #   4. train the disentangled VAE on source batches plus aligned target
    #      batches;
    #   5. periodically score it on the target test split (mean per-batch R^2);
    #   6. on improvement, plot cursor trajectories, compute the maximum
    #      Lyapunov exponents of the encoder cell states, and save a checkpoint.
    # Per-day best scores and exponents are appended to an .xlsx workbook for
    # each hyper-parameter value, followed by the mean/std over all seeds.
    parser = argparse.ArgumentParser(description='train')
    parser.add_argument('-cuda_device', type=str, default='0', help='which gpu to use ')
    parser.add_argument('-dataset', type=str, default='Jango', help='which dataset ')
    parser.add_argument('-batch_size', type=int, default=256)
    parser.add_argument('-src_train_ratio', type=float, default=0.8)
    parser.add_argument('-tgt_train_ratio', type=float, default=0.8)

    args = parser.parse_args()

    # dataset
    # Jango_2015_isometric_wrist_task, Spike_ISO_2012, Mihili_CO_2014, Chewie_CO_2016, Mihili_RT_2013_2014
    data_path = './datasets/Chewie_CO_2016/'
    NHP_id = 'Chewie'
    save_NHP_id = NHP_id
    # save_NHP_id = 'Mihili_RT'
    # The date is sliced out of each file name as the 8 characters following
    # "<NHP_id>" plus one separator character (see src/tgt_data_date below);
    # assumes that naming scheme - TODO confirm for new datasets.
    id_len, date_len = len(NHP_id), 8
    # Evaluation batches with fewer samples than this are skipped.
    batch_size_limit = 10
    # l2 = 1 if NHP_id == 'Mihili' else 0
    mat_list = np.sort(fnmatch.filter(os.listdir(data_path), "*.mat")) # We sorted the files by name.

    # NOTE: this overrides the -dataset CLI flag; the configs below follow the
    # hard-coded save_NHP_id instead.
    args.dataset = save_NHP_id
    # args.dataset = 'Mihili_RT'
    dataset_config = get_dataset_config_class(args.dataset)()
    hyparams_config = get_hyparams_config_class(args.dataset)()
    device = torch.device("cuda:" + args.cuda_device) if torch.cuda.is_available() else torch.device('cpu')

    '''
    # start a new wandb run to track this script
    wandb.init(
        # set the wandb project where this run will be logged
        project="behavior-oriented-alignment",
        
        # track hyperparameters and run metadata
        config={
        "learning_rate": 0.02,
        "architecture": "CNN",
        "dataset": "CIFAR-100",
        "epochs": 10,
        }
    )
    '''
    
    # base_r2_score = [0.7399, 0.6262, 0.7077, 0.6588, 0.5700, 0.6155, 0.5550, 0.6598, 0.5526, -0.1689, 0.4826, 0.4572]
    # seed_list = [1, 2]
    # save_rel_path = './rel/behavior/ablation/spatial_temporal/spatial'

    para_list = [4, 7, 8]
    for para in para_list:
        hyper_name = 'window_size'
        dataset_config.window_size = para
        # dataset_config.domain_weight = para

        print('hyper-' + hyper_name + '=' + str(para))
        save_rel_path = './rel/behavior/hyper/' + hyper_name + '/' + str(dataset_config.window_size) + '/'
        # save_rel_path = './rel/behavior/disentangle/'
        save_rel_file = save_rel_path + 'be_dis_VAE_' + hyper_name + '_' + str(dataset_config.window_size) + '.xlsx'

        # save_rel_path = './rel/behavior/ablation/feedback/' 
        # save_rel_file = save_rel_path + 'be_dis_VAE_velocity_feedback.xlsx'

        # model_save_path = './model_checkpoints/ablation/spatial_temporal/spatial'
        model_save_path = './model_checkpoints/hyper/' + hyper_name + '/' + str(dataset_config.window_size) + '/'
        # model_save_path = './model_checkpoints/ablation/feedback/'
        # model_save_path = './model_checkpoints/'

        # disentanglement
        # Mihili: [1, 6, 7], Chewie: [0, 5, (10, 2)], Jango: [0, seed=0], Spike: [2], Mihili_RT: [s=1(0, 1, 2), s=2(1)]
        # random_seed[ri] seeds the RNGs; rs_list[ri] holds the train/test split
        # seeds evaluated under that RNG seed.
        random_seed = hyparams_config.random_seed
        rs_list = hyparams_config.rs_list

        # random_seed = [0]
        # rs_list = [[0]]
        total_r2_score = []
        total_training_time = []
        for ri in range(len(random_seed)):
            for rs in rs_list[ri]:
                tgt_rel_date, tgt_test_score = [], []
                tgt_lya_max, tgt_lya_max_semantic, tgt_lya_max_specific = [], [], []
                tgt_training_time = []

                for tgt_mat_file in mat_list[0:]:
                    setup_seed(seed=random_seed[ri])
                    # seed_cur = 5
                    # setup_seed(seed=seed_cur)
                    
                    # The first (name-sorted) recording is always the source day.
                    src_data_date, tgt_data_date = mat_list[0][id_len+1:id_len+date_len+1], tgt_mat_file[id_len+1:id_len+date_len+1]
                    print("source date: %s, target date: %s" % (src_data_date, tgt_data_date))
                    print("data preparing...")
                    bin_size, smooth_size = 0.05, 0.1
                    start_time = 'gocue_time'
                    shuffle_flag = False
                    
                    # 0: position 1: velocity 2: accelerated speed
                    cur_idx = 1
                    src_day_spike, src_day_cursor, src_day_unit_names = load_monkey_spike(data_path, src_data_date, bin_size, smooth_size, start_time, NHP_id)
                    src_day_cursor_pos_xy = src_day_cursor[cur_idx]

                    tgt_day_spike, tgt_day_cursor, tgt_day_unit_names = load_monkey_spike(data_path, tgt_data_date, bin_size, smooth_size, start_time, NHP_id)
                    tgt_day_cursor_pos_xy = tgt_day_cursor[cur_idx]

                    #============================================= Pre-processing ==================================================================#
                    # zero-padding empty channels
                    src_day_spike, tgt_day_spike = spike_preprocessing(src_day_unit_names, tgt_day_unit_names, src_day_spike, tgt_day_spike) 

                    # original space alignment
                    # from decoder.wiener_filter import train_src_decoder
                    n_lags = 4
                    # src_decoder_H = train_src_decoder(src_day_spike, src_day_cursor_pos_xy, n_lags, l2)

                    # Cycle-GAN aligner
                    #====================== These parameters control the architecture of the discriminators =============================
                    D_params = {}
                    D_params['hidden_dim'] = int(src_day_spike[0].shape[1]*dataset_config.window_size)

                    #============================= These parameters control the architecture of the generators =============================
                    G_params = {}
                    G_params['hidden_dim'] = int(src_day_spike[0].shape[1]*dataset_config.window_size)

                    #============================= These parameters are for the training process =============================
                    training_params = {}
                    training_params['loss_type'] = 'L1'
                    training_params['optim_type'] = 'Adam'
                    # Chewie: 250
                    training_params['epochs'] = 250 if NHP_id == 'Chewie' else 300
                    training_params['batch_size'] = 256
                    training_params['D_lr'] = 0.001*10
                    training_params['G_lr'] = 0.001
                    training_params['ID_loss_p'] = 5
                    training_params['cycle_loss_p'] = 5
                    training_params['drop_out_D'] = 0.2
                    training_params['drop_out_G'] = 0.2

                    # split training/test set
                    src_day_spike_train, src_day_spike_test, src_day_cursor_train, src_day_cursor_test = \
                        train_test_split(src_day_spike, src_day_cursor_pos_xy, test_size=1.0-args.src_train_ratio, random_state=rs)
                        # random_state=random_seed)
                    tgt_day_spike_train, tgt_day_spike_test, tgt_day_cursor_train, tgt_day_cursor_test = \
                        train_test_split(tgt_day_spike, tgt_day_cursor_pos_xy, test_size=1.0-args.tgt_train_ratio, random_state=rs)

                    # Same-day run: evaluate on the source test split.
                    if src_data_date == tgt_data_date:
                        tgt_day_spike_test, tgt_day_cursor_test = src_day_spike_test, src_day_cursor_test
                    
                    # transform spike data to window segments
                    src_day_win_seg, _ = monkey_spike_transform(src_day_spike, src_day_cursor_pos_xy, dataset_config.window_size)
                    tgt_day_train_win_seg, _ = monkey_spike_transform(tgt_day_spike_train, tgt_day_cursor_train, dataset_config.window_size)
                    
                    # training timing start (named apart from the 'gocue_time' start_time above)
                    train_start_time = time.time()

                    # marginal distribution alignment: align target spikes (align p(x))
                    aligner = train_cycle_gan_aligner(src_day_win_seg, tgt_day_train_win_seg,
                                                      D_params, G_params, training_params,
                                                      args.cuda_device)

                    # initial stability
                    # align training tgt day spikes
                    '''
                    tgt_day_spike_train = test_cycle_gan_aligner(aligner, tgt_day_spike_train, args.cuda_device)

                    tgt_day_spike_test_ori = copy.deepcopy(tgt_day_spike_test)
                    if not src_data_date == tgt_data_date:
                        tgt_day_spike_test = test_cycle_gan_aligner(aligner, tgt_day_spike_test, args.cuda_device)
                    '''

                    src_train_generator = monkey_spike_generator(day_spike=src_day_spike_train, 
                                                                 day_cursor=src_day_cursor_train, 
                                                                 window_size=dataset_config.window_size,
                                                                 batch_size=args.batch_size,
                                                                 is_shuffle=shuffle_flag)
                    src_test_generator = monkey_spike_generator(day_spike=src_day_spike_test, 
                                                                day_cursor=src_day_cursor_test, 
                                                                window_size=dataset_config.window_size,
                                                                batch_size=args.batch_size,
                                                                is_shuffle=shuffle_flag)

                    tgt_train_generator = monkey_spike_generator(day_spike=tgt_day_spike_train,
                                                                 day_cursor=tgt_day_cursor_train,
                                                                 window_size=dataset_config.window_size,
                                                                 batch_size=args.batch_size,
                                                                 is_shuffle=shuffle_flag)
                    tgt_test_generator = monkey_spike_generator(day_spike=tgt_day_spike_test,
                                                                day_cursor=tgt_day_cursor_test,
                                                                window_size=dataset_config.window_size,
                                                                batch_size=args.batch_size,
                                                                is_shuffle=shuffle_flag)
                    
                    tgt_test_set_size = get_monkey_spike_size(day_spike=tgt_day_spike_test,
                                                              day_cursor=tgt_day_cursor_test,
                                                              window_size=dataset_config.window_size)

                    # decoder
                    '''
                    model = vanilla_model(input_dim=src_day_spike[0].shape[1], low_dim=hyparams_config.low_dim, 
                                        drop_out=hyparams_config.drop_prob, h_dim=hyparams_config.h_dim, 
                                        lstm_layer=hyparams_config.lstm_layer, window_size=dataset_config.window_size,
                                        pos_dim=dataset_config.pos_dim)
                    model = vanilla_LSTM(input_size=src_day_spike[0].shape[1], h_dim=hyparams_config.h_dim,
                                        lstm_layer=hyparams_config.lstm_layer, window_size=dataset_config.window_size,
                                        pos_dim=dataset_config.pos_dim)
                    '''
                    '''
                    model = VAE_Model(input_dim=src_day_spike[0].shape[1], low_dim=hyparams_config.low_dim, 
                                    drop_out=hyparams_config.drop_prob, h_dim=hyparams_config.h_dim, 
                                    lstm_layer=hyparams_config.lstm_layer, latent_dim = hyparams_config.latent_dim,
                                    pos_dim=dataset_config.pos_dim,
                                    kld_weight_rec=dataset_config.kld_weight_rec, kld_weight_pos=dataset_config.kld_weight_pos,
                                    rec_weight=dataset_config.rec_weight) 
                    
                    '''

                    # train source cursor predictor
                    target_dim = 2
                    cur_predictor = cursor_Predictor(input_size=src_day_cursor_train[0].shape[1], h_dim=hyparams_config.h_dim,
                                                    lstm_layer=hyparams_config.lstm_layer, pos_dim=target_dim)
                    cur_predictor.to(device)
                    # train_cursor_predictor(cur_predictor, device, src_day_cursor_train, dataset_config, hyparams_config)


                    model = disentangle_VAE_Model(input_dim=src_day_spike[0].shape[1], low_dim=hyparams_config.low_dim, 
                                                drop_out=hyparams_config.drop_prob, h_dim=hyparams_config.h_dim, 
                                                lstm_layer=hyparams_config.lstm_layer, latent_dim = hyparams_config.latent_dim,
                                                pos_dim=dataset_config.pos_dim,
                                                kld_weight_rec=dataset_config.kld_weight_rec, kld_weight_pos=dataset_config.kld_weight_pos,
                                                rec_weight=dataset_config.rec_weight,
                                                mse_weight=dataset_config.mse_weight, domain_weight=dataset_config.domain_weight)
                    model.to(device)
                    
                    # set optimizer
                    optimizer = torch.optim.Adam(model.parameters(), lr=hyparams_config.learning_rate,
                                                weight_decay=hyparams_config.weight_decay)
                    optimizer_ds = torch.optim.Adam(model.domain_classifier.parameters(), lr=hyparams_config.learning_rate,
                                                    weight_decay=hyparams_config.weight_decay)
                    optimizer_br = torch.optim.Adam(model.behavior_classifier.parameters(), lr=hyparams_config.learning_rate,
                                                    weight_decay=hyparams_config.weight_decay)
                    optimizer_cur = torch.optim.Adam(model.pos_read_out.parameters(), lr=hyparams_config.learning_rate,
                                                    weight_decay=hyparams_config.weight_decay)

                    global_step = 0
                    total_train_label_loss = 0.0
                    total_train_domain_loss = 0.0

                    best_score = -10.0
                    best_step = 0
                    # Lyapunov exponents of the best checkpoint of THIS target day.
                    # They stay None (an empty workbook cell) if no evaluation beats
                    # best_score, rather than being undefined or carrying over the
                    # previous target day's values.
                    lya_max, lya_max_semantic, lya_max_specific = None, None, None
                    ##################################################
                    # dis_flag: if semantic latent subspace
                    dis_flag = True
                    # ori_flag: pass target spikes through the cycle-GAN aligner
                    ori_flag = True

                    # r2_score_func = R2Score(num_outputs=dataset_config.pos_dim) 
                    while global_step < hyparams_config.training_steps:
                        model.train()
                        src_train_batch_x, src_train_batch_y, src_train_batch_l = src_train_generator.__next__()

                        tgt_train_batch_x, tgt_train_batch_y, tgt_train_batch_l = tgt_train_generator.__next__()

                        # NOTE(review): skipped steps do not advance global_step; if the
                        # two generators never yield equal-sized batches this loop does
                        # not terminate.
                        if src_train_batch_y.shape[0] != tgt_train_batch_y.shape[0]:
                            continue
                        src_x = torch.tensor(src_train_batch_x).to(device)
                        src_y = torch.tensor(src_train_batch_y).to(device)
                        tgt_x = torch.tensor(tgt_train_batch_x).to(device)
                        # tgt_y = torch.tensor(tgt_train_batch_y).to(device)

                        # tgt_x alignment
                        if ori_flag:
                            tgt_x = test_cycle_gan_aligner(aligner, tgt_x, args.cuda_device)

                        # conditional distribution alignment (align q(z|x))
                        epoch_num = 1
                        for _ in range(epoch_num):   
                            '''
                            batch_y_pred, _, batch_total_loss = model.forward(src_x=src_x, src_y=src_y, train_flag=True)
                            optimizer.zero_grad()
                            batch_total_loss.backward()
                            optimizer.step()

                            # compute r2_score (before training)
                            src_pred_pos, src_gt_pos = batch_y_pred.detach().cpu().numpy(), src_y.detach().cpu().numpy()
                            r2_score_src = r2_score(src_pred_pos, src_gt_pos)
                            print("Total Loss = %f, R2 Score = %f"%(batch_total_loss, r2_score_src))
                            '''
                            # dis_flag: if semantic latent subspace
                            train_disentangle_VAE_Model(model=model, device=device,
                                                        src_x=src_x, src_y=src_y, tgt_x=tgt_x,
                                                        grl_weight=dataset_config.grl_weight,
                                                        hsic_weight=dataset_config.hsic_weight,
                                                        optimizer_VAE=optimizer, optimizer_ds=optimizer_ds, optimizer_br=optimizer_br,
                                                        dis_flag=dis_flag)

                            global_step += 1

                        # conditional distribution alignment: p(y|z)
                        # fine-tune decoder
                        # consistent trial structure regularization
                        # train / test VAE latents
                        # for _ in range(cur_epoch_num):
                        # if global_step % hyparams_config.fine_tune_per_step == 0 and global_step != 0:
                            # tgt_day_spike_train_align = test_cycle_gan_aligner(aligner, tgt_day_spike_train, args.cuda_device)
                            # fine_tune_cursor_decoder(model, cur_predictor, device, tgt_day_spike_train_align, tgt_day_cursor_train, dataset_config, optimizer_cur)


                        # test phase
                        if global_step % hyparams_config.test_per_step == 0 and global_step != 0:
                        # if global_step == hyparams_config.training_steps:
                            '''
                            def get_latent_features(model, day_spike_train):
                                latent_train_list = []
                                for day_spike in day_spike_train:
                                    day_spike = day_spike[np.newaxis, :]
                                    day_spike_tensor = torch.tensor(day_spike).float().to(device)
                                    latent_train_tensor = model.extractor(day_spike_tensor)
                                    latent_train_list.append(latent_train_tensor[0].detach().cpu().numpy())
                                return latent_train_list
                            '''
                            # src_day_latent = get_latent_features(model, src_day_spike_train)
                            # aligner = train_latent_cycle_gan_aligner(src_day_spike_train, tgt_day_spike_train, D_params, G_params, training_params)
                            # aligner = train_cycle_gan_aligner(get_latent_features(model, src_day_spike_train), get_latent_features(model, tgt_day_spike_train),
                                                            # D_params, G_params, training_params)

                            total_tgt_test_label_loss = 0.0
                            total_tgt_test_r2_score = 0.0

                            # Number of test batches; decremented below for every batch
                            # that is too small to be scored, so it ends up as the count
                            # of batches actually averaged.
                            tgt_test_epoch = int(math.ceil(tgt_test_set_size / float(args.batch_size)))
                            model.eval()
                            for _ in range(tgt_test_epoch):
                                with torch.no_grad():
                                    test_batch_tgt_x, test_batch_tgt_y, test_batch_tgt_l = tgt_test_generator.__next__()

                                    if test_batch_tgt_y.shape[0] < batch_size_limit:
                                        tgt_test_epoch -= 1
                                        continue

                                    test_x = torch.tensor(test_batch_tgt_x).to(device)
                                    test_y = torch.tensor(test_batch_tgt_y).to(device)
                                    if ori_flag:
                                        test_x = test_cycle_gan_aligner(aligner, test_x, args.cuda_device)

                                    batch_tgt_y_pred, _, batch_tgt_total_loss =  model.forward(src_x=test_x, src_y=test_y, src_flag=True, domain_flag=False, train_flag=False, dis_flag=dis_flag)
                                    
                                    total_tgt_test_label_loss += batch_tgt_total_loss.detach().cpu().numpy()
                                    # compute r2_score (sklearn order: y_true, y_pred)
                                    batch_tgt_pred_arr, batch_tgt_gt_arr = batch_tgt_y_pred.detach().cpu().numpy(), test_y.detach().cpu().numpy()
                                    r2_score_test = r2_score(batch_tgt_gt_arr, batch_tgt_pred_arr)
                                    total_tgt_test_r2_score += r2_score_test

                            if tgt_test_epoch == 0:
                                # Every batch was below batch_size_limit: nothing to
                                # average, so skip this evaluation instead of dividing by 0.
                                print("no target test batch with at least %d samples, skip evaluation" % batch_size_limit)
                                continue
                            
                            # Reported score is the mean of per-batch R^2 values.
                            mean_tgt_test_label_loss = total_tgt_test_label_loss / tgt_test_epoch
                            # r2_score_test = r2_score_func(tgt_test_y_pred_list, tgt_test_y_true_list)
                            print("total loss = %f"%(mean_tgt_test_label_loss))
                            mean_tgt_test_r2_score = total_tgt_test_r2_score / tgt_test_epoch
                            print("total r2 score = %f"%(mean_tgt_test_r2_score))
                            # minimum improvement required to count as a new best
                            delta_score = 5e-3
                            if mean_tgt_test_r2_score > best_score + delta_score:
                                best_score = mean_tgt_test_r2_score

                                # cursor position visualization
                                tgt_day_target_dir_test = find_target_dir(tgt_day_cursor_test, [-135, -90, -45, 0, 45, 90, 135, 180])
                                save_pred_fig_file = './fig/pred_cursor_pos.png'
                                save_gt_fig_file = './fig/gt_cursor_pos.png'

                                # load tgt day data
                                tgt_day_spike_test_format, tgt_day_cursor_test_format = get_all_input_spike(tgt_day_spike_test, tgt_day_cursor_test, dataset_config.window_size, shuffle_flag)
                                tgt_day_spike_test_format = torch.tensor(tgt_day_spike_test_format).to(device)
                                tgt_day_cursor_test_format = torch.tensor(tgt_day_cursor_test_format).to(device)

                                tgt_y_pred = get_pred_cursor_pos(model, tgt_day_spike_test_format, tgt_day_cursor_test_format)
                                # Convert to NumPy on CPU runs as well as CUDA runs;
                                # tgt_y_gt is required by the ground-truth plot below.
                                if torch.is_tensor(tgt_y_pred):
                                    tgt_y_pred = tgt_y_pred.detach().cpu().numpy()
                                tgt_y_gt = tgt_day_cursor_test_format.detach().cpu().numpy()
                                plot_cursor_position(tgt_y_pred, tgt_day_spike_test, tgt_day_target_dir_test, dataset_config.window_size, save_pred_fig_file)

                                # gt cursor position visualization
                                plot_cursor_position(tgt_y_gt, tgt_day_spike_test, tgt_day_target_dir_test, dataset_config.window_size, save_gt_fig_file)

                                # lyapunov spectrum
                                # save NDS trajectories (latent space)
                                # tgt_day_spike_train_format, _ = get_all_input_spike(tgt_day_spike_test, tgt_day_cursor_test, dataset_config.window_size, shuffle_flag)
                                # tgt_day_spike_train_format = torch.tensor(tgt_day_spike_train_format).to(device)

                                tgt_low = model.read_in_src(tgt_day_spike_test_format)
                                src_hid, _  = model.encoder(tgt_low)
                                cell_state = extract_cell_states(
                                    model=model.encoder,
                                    x_input=tgt_low, hid_state=src_hid,
                                )

                                # calculate lyapunov spectrum on the flattened cell-state trajectory
                                cell_state_np = torch.reshape(cell_state, shape=(-1, src_hid.shape[-1])).detach().cpu().numpy()
                                lya_max = lyapunov_solve_unknown(
                                    x=cell_state_np,
                                    step_sz=2e-2,
                                )
                                print('System maximum Lyapunov Exponent: ', lya_max)

                                # calculate semantic Lyapunov spectrum
                                cell_state = cell_state.to(device)    
                                semantic_cell_state, specific_cell_state = extract_semantic_cell_state(
                                    model=model,
                                    cell_state=cell_state,
                                )

                                cell_state_semantic_np = torch.reshape(semantic_cell_state, shape=(-1, semantic_cell_state.shape[-1])).detach().cpu().numpy()
                                lya_max_semantic = lyapunov_solve_unknown(
                                    x=cell_state_semantic_np,
                                    step_sz=2e-2,
                                )
                                print('System maximum Semantic Lyapunov Exponent: ', lya_max_semantic)

                                cell_state_specific_np = torch.reshape(specific_cell_state, shape=(-1, specific_cell_state.shape[-1])).detach().cpu().numpy()
                                lya_max_specific = lyapunov_solve_unknown(
                                    x=cell_state_specific_np,
                                    step_sz=2e-2,
                                )
                                print('System maximum Specific Lyapunov Exponent: ', lya_max_specific)

                                # save the best model
                                model_save_dir = model_save_path + save_NHP_id + '/' + tgt_data_date + '/' + str(random_seed[ri])
                                if not os.path.exists(model_save_dir):
                                    os.makedirs(model_save_dir)
                                model_file = model_save_dir + '/behavior_disentangle_VAE'
                                torch.save({'GAN_aligner': aligner,
                                            'be_dis_VAE': model.state_dict(),
                                            'be_dis_VAE_model': model,
                                            'cur_predictor': cur_predictor.state_dict(),
                                            # after alignment p(x)
                                            'tgt_day_test_spike': tgt_day_spike_test,
                                            # 'tgt_day_test_spike_ori':tgt_day_spike_test_ori,
                                            'tgt_day_cursor_test': tgt_day_cursor_test,
                                            'tgt_y_pred': tgt_y_pred,
                                            'lya_max': lya_max,
                                            'lya_max_semantic': lya_max_semantic,
                                            'lya_max_specific': lya_max_specific,
                                            }, model_file)

                                # check structure
                                # struc_r2_score = get_cursor_position_structure(cur_predictor, device, tgt_y_pred, tgt_day_spike_test, dataset_config)
                            print("best r2 score = %f"%(best_score))
                    
                    # training timing end: per-step average wall time, which also
                    # includes the periodic evaluation/plotting/checkpoint work above
                    end_time = time.time()
                    training_time = (end_time - train_start_time)/hyparams_config.training_steps
                    tgt_training_time.append(training_time)

                    tgt_rel_date.append(tgt_data_date)
                    tgt_test_score.append(best_score)
                    tgt_lya_max.append(lya_max)
                    tgt_lya_max_semantic.append(lya_max_semantic)
                    tgt_lya_max_specific.append(lya_max_specific)
                
                # add results into total lists
                total_r2_score.append(tgt_test_score)
                total_training_time.append(tgt_training_time)
                
                # save results: append one section per (seed, rs) to the workbook,
                # creating the file / sheet on first use
                # wb = openpyxl.Workbook()
                # wb.create_sheet(title=save_NHP_id)
                if not os.path.exists(save_rel_path):
                    os.makedirs(save_rel_path)
                rel_file = Path(save_rel_file)
                if not rel_file.exists():
                    wb = openpyxl.Workbook()
                else:
                    wb = openpyxl.load_workbook(save_rel_file)

                # wb.create_sheet(title=NHP_id)
                if save_NHP_id not in wb.sheetnames:
                    wb.create_sheet(title=save_NHP_id)
                ws = wb[save_NHP_id]
                ws.append([])
                ws.append(['grl_weight = ' + str(dataset_config.grl_weight) + ', hsic_weight = ' + str(dataset_config.hsic_weight) + ', domain_weight = ' + str(dataset_config.domain_weight)])
                ws.append(['seed = ' + str(random_seed[ri]) + ', rs = ' + str(rs)])
                ws.append(tgt_rel_date)
                ws.append(tgt_test_score)
                ws.append(tgt_lya_max)
                ws.append(tgt_lya_max_semantic)
                ws.append(tgt_lya_max_specific)
                wb.save(save_rel_file)
        
        # save total score results: per-target-day mean / std of the best R^2
        # and mean per-step time over all (seed, rs) runs
        total_r2_score = np.array(total_r2_score)
        r2_score_avg = np.mean(total_r2_score, axis=0)
        r2_score_std = np.std(total_r2_score, axis=0)
        training_time_avg = np.mean(total_training_time, axis=0)

        wb = openpyxl.load_workbook(save_rel_file)
        ws = wb[save_NHP_id]
        ws.append([])
        ws.append(tgt_rel_date)
        ws.append(list(r2_score_avg))
        ws.append(list(r2_score_std))
        ws.append(list(training_time_avg))
        wb.save(save_rel_file)        